ベイズ統計モデリングとMCMC

伊東宏樹

2024-11-23

内容

  • ベイズ統計モデリング

  • MCMC(マルコフ連鎖モンテカルロ)法

    • NIMBLEを使った例

    • Stanを使った例

統計モデルとは

  • 変数間の関係を何らかの確率分布を使って(パーツとして)記述して作成したモデル(模型)
  • 作成したモデルでシステムを説明したり、予測したりする

ベイズの定理

\[ P(Y \mid X) = \frac{P(X \mid Y)P(Y)}{P(X)} \]

\[ = \frac{P(X \mid Y)P(Y)}{P(X \mid Y)P(Y)+P(X \mid \overline{Y})P(\overline{Y})} \]

例題

1000人に1人がかかる病気があるとする。

検査をすると、この病気にかかっている場合には99%の確率で陽性となる。ただし、かかっていなくても5%の確率で誤って陽性になる。

ある人が検査を受けて陽性になった。このとき実際にこの人がこの病気にかかっている確率は何パーセントか。

\[ \frac{0.99 \times 0.001}{0.99 \times 0.001 + 0.05 \times 0.999} = 0.01943463 \]

実際にこの病気にかかっている確率はおよそ2%

ベイズ推定

  • 事前確率を、得られたデータで更新していく

  • 確率分布を推定するときは、事前分布→事後分布

MCMC(マルコフ連鎖モンテカルロ)法とは

  • モデルのパラメータを推定する手法

  • MCMC = MC(マルコフ連鎖) + MC(モンテカルロ)

マルコフ連鎖

1期前の状態にのみ依存する確率変数列

例: ランダムウオーク

モンテカルロ法

乱数を使った推定法

例: 円周率を求める

MCMC

  • ベイズ統計モデルで、複雑な統計モデルのパラメータ推定に使われる

  • 乱数を使って、一定のアルゴリズム(Metropolis-Hastings法, Gibbsサンプリング, Hamiltonian Monte Carlo法など)により、事後分布からサンプリングしたと見なせるマルコフ連鎖を生成する

MCMCのソフトウェア

など

いずれもRとは別のモデル記述言語で、モデルを記述する

NIMBLEを使った統計モデリングの例

データ

群ごとに切片が異なるが、群内では傾きはだいたい2くらいでどれも同程度

群を無視すると

lm(Y ~ X, data = df) |> summary()

Call:
lm(formula = Y ~ X, data = df)

Residuals:
    Min      1Q  Median      3Q     Max 
-2.2448 -0.8732 -0.1221  0.8737  2.6287 

Coefficients:
            Estimate Std. Error t value Pr(>|t|)    
(Intercept)   2.6253     0.4879   5.381 2.63e-07 ***
X             1.3684     0.1230  11.124  < 2e-16 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 1.161 on 158 degrees of freedom
Multiple R-squared:  0.4392,    Adjusted R-squared:  0.4357 
F-statistic: 123.8 on 1 and 158 DF,  p-value: < 2.2e-16

傾きを過小評価してしまった

混合効果モデル

固定効果+変量効果

  • 固定効果: 説明変数による目的変数への効果

  • 変量効果(ランダム効果): 群の違いによる効果

    • 通常、変量効果は正規分布にしたがうとする

    • ベイズ統計モデリングでは、階層事前分布を設定→階層ベイズモデル

NIMBLEモデル

BUGS言語で統計モデルを記述

code <- nimbleCode({
  for (n in 1:N) {
    mu[n] <- alpha + beta * X[n] + epsilon[Gind[n]]
    Y[n] ~ dnorm(mu[n], tau[1])
  }
  for (g in 1:G) {
    epsilon[g] ~ dnorm(0, tau[2])
  }
  alpha ~ dnorm(0, 1e-4)
  beta ~ dnorm(0, 1e-4)
  for (i in 1:2) {
    tau[i] <- 1 / (sigma[i] * sigma[i])
    sigma[i] ~ dunif(0, 100)
  }
})

コンパイル・実行

G <- length(levels(df$Group))
out <- nimbleMCMC(code = code,
                  constants = list(N = nrow(df),
                                   G = G,
                                   Gind = as.numeric(df$Group)),
                  data = list(Y = df$Y, X = df$X),
                  inits = list(alpha = -2, beta = -2,
                               epsilon = rep(0, G),
                               sigma = c(4, 2)),
                  niter = 500, nburnin = 0,
                  samplesAsCodaMCMC = TRUE)

結果

betaのマルコフ連鎖の軌跡

burn-in

初期値の影響が残っている部分は捨てる

サンプリング

out <- nimbleMCMC(code = code,
                  constants = list(N = nrow(df), G = G,
                                   Gind = as.numeric(df$Group)),
                  data = list(Y = df$Y, X = df$X),
                  inits = function() {
                    list(alpha = runif(1, -2, 2),
                         beta = runif(1, -2, 2),
                    epsilon = runif(G, -2, 2),
                    sigma = runif(2, 0, 2))},
                  nchains = 3, niter = 12000, nburnin = 2000,
                  samplesAsCodaMCMC = TRUE)

traceplot (alpha)

マルコフ連鎖の軌跡プロット(codaパッケージのtraceplot関数を使用)

traceplot(out[, "alpha"])

traceplot (sigma[1])

traceplot(out[, "sigma[1]"])

R-hat

MCMC計算が収束したかどうかの指標値。1.1以下ならOKとする場合が多い。

gelman.diag(out)
Potential scale reduction factors:

         Point est. Upper C.I.
alpha          1.02       1.05
beta           1.02       1.06
sigma[1]       1.00       1.00
sigma[2]       1.01       1.02

Multivariate psrf

1.01

結果

結果の要約

summary(out)

Iterations = 1:10000
Thinning interval = 1 
Number of chains = 3 
Sample size per chain = 10000 

1. Empirical mean and standard deviation for each variable,
   plus standard error of the mean:

           Mean      SD  Naive SE Time-series SE
alpha    0.8541 0.68319 0.0039444      0.0472391
beta     1.8285 0.15104 0.0008720      0.0092044
sigma[1] 1.0015 0.05883 0.0003397      0.0007483
sigma[2] 0.9245 0.34924 0.0020163      0.0080288

2. Quantiles for each variable:

            2.5%    25%    50%   75% 97.5%
alpha    -0.5328 0.3963 0.8588 1.329 2.163
beta      1.5270 1.7255 1.8311 1.935 2.114
sigma[1]  0.8941 0.9604 0.9986 1.040 1.125
sigma[2]  0.4594 0.6876 0.8548 1.078 1.818

密度グラフ

densplot(out[, "beta"])

Stanを使った統計モデリング

RからStanを使う方法

  • rstanパッケージ
  • cmdstanrパッケージ

今回は前者を使用

Stanのモデル

Stanで記述した同等のモデル。各パラメータの事前分布は弱情報事前分布とした。

lme.stan
/*
  stan model for linear mixed effects model
*/

data {
  int<lower=0> N;
  int<lower=0> G;
  array[N] int<lower=1,upper=G> Gind;
  vector[N] X;
  vector[N] Y;
}

parameters {
  real alpha;
  real beta;
  vector[G] epsilon;
  vector<lower=0>[2] sigma;
}

transformed parameters {
   vector[N] mu;
   
   for (n in 1:N)
     mu[n] = alpha + beta * X[n] + epsilon[Gind[n]];
}

model {
  Y ~ normal(mu, sigma[1]);
  epsilon ~ normal(0, sigma[2]);
  alpha ~ normal(0, 10);
  beta ~ normal(0, 10);
  sigma ~ normal(0, 5);
}

実行・サンプリング

fit <- stan(file = file.path("model", "lme.stan"),
            data = list(N = nrow(df),
                        G = G,
                        Gind = as.numeric(df$Group),
                        X = df$X, Y = df$Y),
            pars = c("alpha", "beta", "sigma"),
            iter = 2000, warmup = 1000)

結果

各パラメータの事後分布の要約

                mean     se_mean        sd        2.5%         25%         50%
alpha      0.7308328 0.024709076 0.6811684  -0.6735333   0.2738453   0.7407015
beta       1.8545651 0.004053186 0.1478083   1.5546949   1.7555509   1.8592507
sigma[1]   1.0002268 0.001228641 0.0590671   0.8948997   0.9599595   0.9966786
sigma[2]   0.9393150 0.012137218 0.3704167   0.4832510   0.6973019   0.8557121
lp__     -81.9555096 0.081620377 2.5565093 -87.7553986 -83.5292805 -81.6296027
                75%      97.5%     n_eff     Rhat
alpha      1.195513   2.059737  759.9691 1.004501
beta       1.953983   2.136503 1329.8567 1.003722
sigma[1]   1.037070   1.126986 2311.2193 1.003092
sigma[2]   1.087273   1.889887  931.4144 1.005365
lp__     -80.056135 -77.932245  981.0645 1.002013

参考文献